{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "85d99234",
   "metadata": {},
   "source": [
    "Setting the default Transformers and Datasets cache directories to a specified path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5a9ec2ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['TRANSFORMERS_CACHE'] = \"D:/transformer_cache/\"\n",
    "os.environ['HF_DATASETS_CACHE'] = \"D:/transformer_cache/\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0400b39",
   "metadata": {},
   "source": [
    "Importing the cnn_dailymail dataset and printing its columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6b8e9ab4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Features: ['article', 'highlights', 'id']\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"abisee/cnn_dailymail\", \"3.0.0\")\n",
    "print(f\"Features: {dataset['train'].column_names}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13b40b45",
   "metadata": {},
   "source": [
    "Printing first 500 characters of the article at index [1] and its corresponding summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "51f21dd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Article (excerpt of 500 characters, total length: 4051):\n",
      "\n",
      "Editor's note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events. Here, Soledad O'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the \"forgotten floor,\" where many mentally ill inmates are housed in Miami before trial. MIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the \"forgotten floor.\" Here, inmates with the most s\n",
      "\n",
      "Summary (length: 281):\n",
      "Mentally ill inmates in Miami are housed on the \"forgotten floor\"\n",
      "Judge Steven Leifman says most are there as a result of \"avoidable felonies\"\n",
      "While CNN tours facility, patient shouts: \"I am the son of the president\"\n",
      "Leifman says the system is unjust and he's fighting for change .\n"
     ]
    }
   ],
   "source": [
    "sample = dataset[\"train\"][1]\n",
    "print(f\"\"\"\n",
    "Article (excerpt of 500 characters, total length: {len(sample[\"article\"])}):\n",
    "\"\"\")\n",
    "print(sample[\"article\"][:500])\n",
    "print(f'\\nSummary (length: {len(sample[\"highlights\"])}):')\n",
    "print(sample[\"highlights\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d09acfc1",
   "metadata": {},
   "source": [
    "We restrict the input to the first 2,000 characters of the article at index [1] and store it in `sample_text`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "614ddadd",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_text = dataset[\"train\"][1][\"article\"][:2000]\n",
    "# We'll collect the generated summaries of each model in a dictionary\n",
    "summaries = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c76e0da",
   "metadata": {},
   "source": [
    "We use NLTK's `sent_tokenize` to split the input text into sentences, and try it on a short example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6a7bf2e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to\n",
      "[nltk_data]     C:\\Users\\USER\\AppData\\Roaming\\nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "['The U.S. are a country.', 'The U.N. is an organization.']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import nltk\n",
    "from nltk.tokenize import sent_tokenize\n",
    "nltk.download(\"punkt\")\n",
    "string = \"The U.S. are a country. The U.N. is an organization.\"\n",
    "sent_tokenize(string)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e788a0f6",
   "metadata": {},
   "source": [
    "We use the first three sentences as our summary baseline to compare the generated summaries against"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d2107370",
   "metadata": {},
   "outputs": [],
   "source": [
    "def three_sentence_summary(text):\n",
    "    return \"\\n\".join(sent_tokenize(text)[:3])\n",
    "summaries[\"baseline\"] = three_sentence_summary(sample_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ff7353e",
   "metadata": {},
   "source": [
    "Loading the GPT-2 model (the `gpt2-xl` checkpoint). The model can be used to generate summaries by appending \"TL;DR\" to the end of the input text, as done in the code below.\n",
    "`max_length=512` caps the total sequence length (prompt plus generated text) at 512 tokens. `clean_up_tokenization_spaces=True` removes extra spaces introduced during tokenization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5fa93b60",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\USER\\anaconda3\\lib\\site-packages\\transformers\\utils\\hub.py:127: FutureWarning: Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.\n",
      "  warnings.warn(\n",
      "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n",
      "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n",
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline, set_seed\n",
    "set_seed(42)\n",
    "pipe = pipeline(\"text-generation\", model=\"gpt2-xl\")\n",
    "gpt2_query = sample_text + \"\\nTL;DR:\\n\"\n",
    "pipe_out = pipe(gpt2_query, max_length=512, clean_up_tokenization_spaces=True)\n",
    "summaries[\"gpt2\"] = \"\\n\".join(sent_tokenize(pipe_out[0][\"generated_text\"][len(gpt2_query) :]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57a98dbc",
   "metadata": {},
   "source": [
    "Loading the T5 model and creating a similar pipeline.\n",
    "T5 is a text-to-text model and can be used for summarization by providing the input as `\"summarize: <ARTICLE>\"`. We can load T5 for summarization directly with `pipeline()` without worrying about formatting the inputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a427e48f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
     ]
    }
   ],
   "source": [
    "pipe = pipeline(\"summarization\", model=\"t5-large\")\n",
    "pipe_out = pipe(sample_text)\n",
    "summaries[\"t5\"] = \"\\n\".join(sent_tokenize(pipe_out[0][\"summary_text\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb9cf177",
   "metadata": {},
   "source": [
    "Loading the BART model and creating a similar pipeline.\n",
    "BART uses an encoder-decoder architecture and is trained to reconstruct corrupted inputs.\n",
    "We use the `facebook/bart-large-cnn` checkpoint, which has been fine-tuned on the CNN/DailyMail dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e0554a42",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
     ]
    }
   ],
   "source": [
    "pipe = pipeline(\"summarization\", model=\"facebook/bart-large-cnn\")\n",
    "pipe_out = pipe(sample_text)\n",
    "summaries[\"bart\"] = \"\\n\".join(sent_tokenize(pipe_out[0][\"summary_text\"]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "926b7864",
   "metadata": {},
   "source": [
    "Loading the PEGASUS model. It has an encoder-decoder architecture, and its pre-training objective is to reconstruct masked sentences.\n",
    "The model emits a `<n>` token for newlines, so `sent_tokenize()` is not needed to split the summary into sentences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1dc01237",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\USER\\anaconda3\\lib\\site-packages\\torch\\_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n",
      "Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-cnn_dailymail and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n"
     ]
    }
   ],
   "source": [
    "pipe = pipeline(\"summarization\", model=\"google/pegasus-cnn_dailymail\")\n",
    "pipe_out = pipe(sample_text)\n",
    "summaries[\"pegasus\"] = pipe_out[0][\"summary_text\"].replace(\" .<n>\", \".\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17922989",
   "metadata": {},
   "source": [
    "We print the reference summary stored in the dataset, the baseline summary, and the summaries generated by GPT-2, T5, BART and PEGASUS, and compare them qualitatively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "766636f7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GROUND TRUTH\n",
      "Mentally ill inmates in Miami are housed on the \"forgotten floor\"\n",
      "Judge Steven Leifman says most are there as a result of \"avoidable felonies\"\n",
      "While CNN tours facility, patient shouts: \"I am the son of the president\"\n",
      "Leifman says the system is unjust and he's fighting for change .\n",
      "\n",
      "BASELINE\n",
      "Editor's note: In our Behind the Scenes series, CNN correspondents share their experiences in covering news and analyze the stories behind the events.\n",
      "Here, Soledad O'Brien takes users inside a jail where many of the inmates are mentally ill. An inmate housed on the \"forgotten floor,\" where many mentally ill inmates are housed in Miami before trial.\n",
      "MIAMI, Florida (CNN) -- The ninth floor of the Miami-Dade pretrial detention facility is dubbed the \"forgotten floor.\"\n",
      "\n",
      "GPT2\n",
      "To get to the jail, you go up a flight of stairs, pass a metal detector, and go down a hall.\n",
      "The inmates are so scared that they'll be thrown in jail for not turning up your appointment.\n",
      "The first room to the right is where the mentally ill inmates are housed, and the next room to the bottom is where the regular inmates go.\n",
      "At the end,\n",
      "\n",
      "T5\n",
      "mentally ill inmates are housed on the ninth floor of a florida jail .\n",
      "most face drug charges or charges of assaulting an officer .\n",
      "judge says arrests often result from confrontations with police .\n",
      "one-third of all people in Miami-dade county jails are mental ill .\n",
      "\n",
      "BART\n",
      "Mentally ill inmates are housed on the \"forgotten floor\" of Miami-Dade jail.\n",
      "Most often, they face drug charges or charges of assaulting an officer.\n",
      "Judge Steven Leifman says the arrests often result from confrontations with police.\n",
      "He says about one-third of all people in the county jails are mentally ill.\n",
      "\n",
      "PEGASUS\n",
      "Mentally ill inmates in Miami are housed on the \"forgotten floor\"<n>The ninth floor is where they're held until they're ready to appear in court.\n",
      "Most often, they face drug charges or charges of assaulting an officer.\n",
      "They end up on the ninth floor severely mentally disturbed .\n",
      "\n"
     ]
    }
   ],
   "source": [
    "#print(\"ARTICLE\")\n",
    "#print(dataset[\"train\"][1][\"article\"][:2000])\n",
    "#print(\"\")\n",
    "print(\"GROUND TRUTH\")\n",
    "print(dataset[\"train\"][1][\"highlights\"])\n",
    "print(\"\")\n",
    "for model_name in summaries:\n",
    "    print(model_name.upper())\n",
    "    print(summaries[model_name])\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0627177f",
   "metadata": {},
   "source": [
    "We now load the ROUGE metric to compare the generated summaries quantitatively"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7b7b09de",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\USER\\AppData\\Local\\Temp\\ipykernel_23300\\2048908469.py:2: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
      "  rouge_metric = load_metric(\"rouge\")\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_metric\n",
    "rouge_metric = load_metric(\"rouge\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5fd5fe67",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: rouge_score in c:\\users\\user\\anaconda3\\lib\\site-packages (0.1.2)\n",
      "Requirement already satisfied: six>=1.14.0 in c:\\users\\user\\anaconda3\\lib\\site-packages (from rouge_score) (1.16.0)\n",
      "Requirement already satisfied: numpy in c:\\users\\user\\anaconda3\\lib\\site-packages (from rouge_score) (1.22.4)\n",
      "Requirement already satisfied: absl-py in c:\\users\\user\\anaconda3\\lib\\site-packages (from rouge_score) (1.4.0)\n",
      "Requirement already satisfied: nltk in c:\\users\\user\\anaconda3\\lib\\site-packages (from rouge_score) (3.7)\n",
      "Requirement already satisfied: click in c:\\users\\user\\anaconda3\\lib\\site-packages (from nltk->rouge_score) (8.0.4)\n",
      "Requirement already satisfied: tqdm in c:\\users\\user\\anaconda3\\lib\\site-packages (from nltk->rouge_score) (4.66.4)\n",
      "Requirement already satisfied: joblib in c:\\users\\user\\anaconda3\\lib\\site-packages (from nltk->rouge_score) (1.1.0)\n",
      "Requirement already satisfied: regex>=2021.8.3 in c:\\users\\user\\anaconda3\\lib\\site-packages (from nltk->rouge_score) (2022.3.15)\n",
      "Requirement already satisfied: colorama in c:\\users\\user\\anaconda3\\lib\\site-packages (from click->nltk->rouge_score) (0.4.4)\n"
     ]
    }
   ],
   "source": [
    "!pip install rouge_score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f17dd90",
   "metadata": {},
   "source": [
    "We compute ROUGE scores for each model's summary of the sample article against its reference summary from the CNN/DailyMail dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4a5cb537",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rouge1</th>\n",
       "      <th>rouge2</th>\n",
       "      <th>rougeL</th>\n",
       "      <th>rougeLsum</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>baseline</th>\n",
       "      <td>0.365079</td>\n",
       "      <td>0.145161</td>\n",
       "      <td>0.206349</td>\n",
       "      <td>0.285714</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>gpt2</th>\n",
       "      <td>0.271186</td>\n",
       "      <td>0.051724</td>\n",
       "      <td>0.152542</td>\n",
       "      <td>0.271186</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>t5</th>\n",
       "      <td>0.382979</td>\n",
       "      <td>0.130435</td>\n",
       "      <td>0.255319</td>\n",
       "      <td>0.382979</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>bart</th>\n",
       "      <td>0.475248</td>\n",
       "      <td>0.222222</td>\n",
       "      <td>0.316832</td>\n",
       "      <td>0.415842</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            rouge1    rouge2    rougeL  rougeLsum\n",
       "baseline  0.365079  0.145161  0.206349   0.285714\n",
       "gpt2      0.271186  0.051724  0.152542   0.271186\n",
       "t5        0.382979  0.130435  0.255319   0.382979\n",
       "bart      0.475248  0.222222  0.316832   0.415842"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_metric\n",
    "import pandas as pd\n",
    "rouge_metric = load_metric(\"rouge\",trust_remote_code=True)\n",
    "reference = dataset[\"train\"][1][\"highlights\"]\n",
    "records = []\n",
    "rouge_names = [\"rouge1\", \"rouge2\", \"rougeL\", \"rougeLsum\"]\n",
    "for model_name in summaries:\n",
    "    rouge_metric.add(prediction=summaries[model_name], reference=reference)\n",
    "    score = rouge_metric.compute()\n",
    "    rouge_dict = dict((rn, score[rn].mid.fmeasure) for rn in rouge_names)\n",
    "    records.append(rouge_dict)\n",
    "pd.DataFrame.from_records(records, index=summaries.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "271dda4f",
   "metadata": {},
   "source": [
    "We define a function that evaluates the three-sentence baseline using the ROUGE metric"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "70ae02d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_summaries_baseline(dataset, metric,column_text=\"article\",column_summary=\"highlights\"):\n",
    "    summaries = [three_sentence_summary(text) for text in dataset[column_text]]\n",
    "    metric.add_batch(predictions=summaries,\n",
    "    references=dataset[column_summary])\n",
    "    score = metric.compute()\n",
    "    return score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eceeb981",
   "metadata": {},
   "source": [
    "We evaluate the baseline on a random sample of 500 examples from the test split, since evaluating on the entire dataset is computationally expensive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c71ba065",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rouge1</th>\n",
       "      <th>rouge2</th>\n",
       "      <th>rougeL</th>\n",
       "      <th>rougeLsum</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>baseline</th>\n",
       "      <td>0.388588</td>\n",
       "      <td>0.168604</td>\n",
       "      <td>0.242256</td>\n",
       "      <td>0.351425</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            rouge1    rouge2    rougeL  rougeLsum\n",
       "baseline  0.388588  0.168604  0.242256   0.351425"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_sampled = dataset[\"test\"].shuffle(seed=42).select(range(500))\n",
    "score = evaluate_summaries_baseline(test_sampled, rouge_metric)\n",
    "rouge_dict = dict((rn, score[rn].mid.fmeasure) for rn in rouge_names)\n",
    "pd.DataFrame.from_dict(rouge_dict, orient=\"index\", columns=[\"baseline\"]).T"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c4be1ae",
   "metadata": {},
   "source": [
    "We now evaluate the PEGASUS model on a subset of the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "5a0b0d7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "import torch\n",
    "\n",
    "#device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device = \"cpu\"\n",
    "def chunks(list_of_elements, batch_size):\n",
    "    \"\"\"Yield successive batch-sized chunks from list_of_elements.\"\"\"\n",
    "    for i in range(0, len(list_of_elements), batch_size):\n",
    "        yield list_of_elements[i : i + batch_size]\n",
    "\n",
    "def evaluate_summaries_pegasus(dataset, metric, model, tokenizer, \n",
    "                               batch_size=16, device=device, \n",
    "                               column_text=\"article\", \n",
    "                               column_summary=\"highlights\"):\n",
    "    article_batches = list(chunks(dataset[column_text], batch_size))\n",
    "    target_batches = list(chunks(dataset[column_summary], batch_size))\n",
    "\n",
    "    for article_batch, target_batch in tqdm(\n",
    "        zip(article_batches, target_batches), total=len(article_batches)):\n",
    "        \n",
    "        inputs = tokenizer(article_batch, max_length=1024,  truncation=True, \n",
    "                        padding=\"max_length\", return_tensors=\"pt\")\n",
    "        \n",
    "        summaries = model.generate(input_ids=inputs[\"input_ids\"].to(device),\n",
    "                         attention_mask=inputs[\"attention_mask\"].to(device), \n",
    "                         length_penalty=0.8, num_beams=8, max_length=128)\n",
    "        \n",
    "        decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True, \n",
    "                                clean_up_tokenization_spaces=True) \n",
    "               for s in summaries]\n",
    "        decoded_summaries = [d.replace(\"\", \" \") for d in decoded_summaries]\n",
    "        metric.add_batch(predictions=decoded_summaries, references=target_batch)\n",
    "        \n",
    "    score = metric.compute()\n",
    "    return score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ce5e1d8c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\USER\\anaconda3\\lib\\site-packages\\torch\\_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n",
      "Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-cnn_dailymail and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# hide_output\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "import pandas as pd\n",
    "\n",
    "#device=\"cpu\"\n",
    "model_ckpt = \"google/pegasus-cnn_dailymail\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "dd35d572",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-cnn_dailymail and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "100%|███████████████████████████████████████████████████████████████████████████████| 63/63 [4:05:47<00:00, 234.09s/it]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rouge1</th>\n",
       "      <th>rouge2</th>\n",
       "      <th>rougeL</th>\n",
       "      <th>rougeLsum</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>pegasus</th>\n",
       "      <td>0.012722</td>\n",
       "      <td>0.000559</td>\n",
       "      <td>0.012589</td>\n",
       "      <td>0.012641</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           rouge1    rouge2    rougeL  rougeLsum\n",
       "pegasus  0.012722  0.000559  0.012589   0.012641"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# hide_output\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "import pandas as pd\n",
    "\n",
    "#device=\"cpu\"\n",
    "model_ckpt = \"google/pegasus-cnn_dailymail\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)\n",
    "#score = evaluate_summaries_pegasus(test_sampled, rouge_metric, \n",
    "#                                   model, tokenizer, batch_size=8)\n",
    "#rouge_dict = dict((rn, score[rn].mid.fmeasure) for rn in rouge_names)\n",
    "#pd.DataFrame(rouge_dict, index=[\"pegasus\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "880bac59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide_input \n",
    "pd.DataFrame(rouge_dict, index=[\"pegasus\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7a3b738",
   "metadata": {},
   "source": [
    "# Training a Summarization Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef9e2191",
   "metadata": {},
   "source": [
    "We use the SAMSum dataset developed by Samsung, which consists of a collection of dialogues along with brief summaries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cf3ef1e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using the latest cached version of the module from C:\\Users\\USER\\.cache\\huggingface\\modules\\datasets_modules\\datasets\\samsum\\f1d7c6b7353e6de335d444e424dc002ef70d1277109031327bc9cc6af5d3d46e (last modified on Wed Jul  3 19:21:09 2024) since it couldn't be found locally at samsum, or remotely on the Hugging Face Hub.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Split lengths: [14732, 819, 818]\n",
      "Features: ['id', 'dialogue', 'summary']\n",
      "\n",
      "Dialogue:\n",
      "Hannah: Hey, do you have Betty's number?\n",
      "Amanda: Lemme check\n",
      "Hannah: <file_gif>\n",
      "Amanda: Sorry, can't find it.\n",
      "Amanda: Ask Larry\n",
      "Amanda: He called her last time we were at the park together\n",
      "Hannah: I don't know him well\n",
      "Hannah: <file_gif>\n",
      "Amanda: Don't be shy, he's very nice\n",
      "Hannah: If you say so..\n",
      "Hannah: I'd rather you texted him\n",
      "Amanda: Just text him 🙂\n",
      "Hannah: Urgh.. Alright\n",
      "Hannah: Bye\n",
      "Amanda: Bye bye\n",
      "\n",
      "Summary:\n",
      "Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry.\n"
     ]
    }
   ],
   "source": [
    "dataset_samsum = load_dataset(\"samsum\",trust_remote_code=True)\n",
    "split_lengths = [len(dataset_samsum[split])for split in dataset_samsum]\n",
    "print(f\"Split lengths: {split_lengths}\")\n",
    "print(f\"Features: {dataset_samsum['train'].column_names}\")\n",
    "print(\"\\nDialogue:\")\n",
    "print(dataset_samsum[\"test\"][0][\"dialogue\"])\n",
    "print(\"\\nSummary:\")\n",
    "print(dataset_samsum[\"test\"][0][\"summary\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45e6f344",
   "metadata": {},
   "source": [
    "We evaluate PEGASUS on the SAMSum dataset before fine-tuning."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "491db9b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-cnn_dailymail and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# hide_output\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "import pandas as pd\n",
    "\n",
    "#device=\"cpu\"\n",
    "model_ckpt = \"google/pegasus-cnn_dailymail\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d3225b57",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 103/103 [6:15:02<00:00, 218.47s/it]\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_metric\n",
    "\n",
    "rouge_metric = load_metric(\"rouge\", trust_remote_code=True)\n",
    "# Score the checkpoint on the SAMSum test split before fine-tuning.\n",
    "# The ROUGE table is built in the next cell, where `rouge_names` is defined.\n",
    "score = evaluate_summaries_pegasus(\n",
    "    dataset_samsum[\"test\"], rouge_metric, model, tokenizer,\n",
    "    column_text=\"dialogue\", column_summary=\"summary\", batch_size=8)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b2c5511",
   "metadata": {},
   "source": [
    "ROUGE scores on the SAMSum test set before fine-tuning:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9cd0c1bb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>rouge1</th>\n",
       "      <th>rouge2</th>\n",
       "      <th>rougeL</th>\n",
       "      <th>rougeLsum</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>pegasus</th>\n",
       "      <td>0.015564</td>\n",
       "      <td>0.000294</td>\n",
       "      <td>0.015572</td>\n",
       "      <td>0.015588</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "           rouge1    rouge2    rougeL  rougeLsum\n",
       "pegasus  0.015564  0.000294  0.015572   0.015588"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rouge_names = [\"rouge1\", \"rouge2\", \"rougeL\", \"rougeLsum\"]\n",
    "rouge_dict = dict((rn, score[rn].mid.fmeasure) for rn in rouge_names)\n",
    "pd.DataFrame(rouge_dict, index=[\"pegasus\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9385cd37",
   "metadata": {},
   "source": [
    "# Fine-tuning PEGASUS on SAMSum dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b42f90b3",
   "metadata": {},
   "source": [
    "We plot the distributions of token lengths for the dialogues and summaries in the training split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "3c440b28",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (1044 > 1024). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsgAAAD0CAYAAACGjNCJAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAmX0lEQVR4nO3dfZglZX3n//cHEFBQAZkfgYE4qGwM+oujGRGjm0UIj6vBbNQF/Sm6RDSLWdkYFaKJT5DVvbKi/lZRogg+RMSniEhERIxrdgUHRWRAwyiQGUQYniEYEPzuH3W3FE339OmePud097xf13Wurrrrrqpvneq+z7fvc1dVqgpJkiRJnS3GHYAkSZK0kJggS5IkST0myJIkSVKPCbIkSZLUY4IsSZIk9ZggS5IkST0myBqZJB9M8hcD1v1Gkj8adkzDlqSSPGHccYzCUjlnksYryYrWdm417lhGIck1SX5v3HHowUyQNS/aH/jPk9yZ5LYk/zvJq5P86nesql5dVe8YZ5xzleTvk9zVXr9Icm9v/oNjju2tST6x1PcpLXZJnt3axtuT3JLkH5M8fdxxDVOSNb228v4k/9qb//Mxx3Z6khOX+j41N5vFf2camedV1deSPBr4d8B7gWcArxhvWJuuqg6dmE5yOrC+qt48vogkLSZJHgWcA/wxcBawNfBvgXvGGddsJQmQqvrlIPWr6km9db8BfKKqPjyk8KR5Yw+y5l1V3V5VZwP/ETgqyZPhwf85J9kxyTlJNiS5tU3vPtX2kmyR5M1Jrk1yY5KPtSR8YvnL2rKbk/xF/+uqyf+tJ9kvyfre/G5JPtfiuDrJf5nt8SZ5ZZK1rUfo7CS7TVPv2UnWJdmvzf+nJFe24z8vyWN7dav1wF/VeuTf3z6YZhvbvq3H6rYk35/Yd1v2jSTvaL1Ydyb5apKde8unfF+THAL8OfAfWy/Q93u7fOx025M2c/8GoKo+VVX3V9XPq+qrVXUZPPRbmUwaZtD+Xk9sf893JflSksck+WSSO5J8J8mK3vqV5D+3NuTO9rf++Lb+HUnOSrJ1q7vR9rjt+6Qk/wjcDbwuySX9g0vyp0m+OOibMVO7PqnuH7b258ltveOT/Li1TWcl2WnSe3ZUkn9OclOSNw0a06R9PjfJpXngG9Hf6i27JsmfJbks3bcBn06ybW/5G5Jcn+SnSf6oxfSEJMcALwHeMHEOe7tcOd32NB4myBqaqroYWE/XSzLZFsBHgccCvw78HPif02zq5e31HOBxwPYTdZPsDXyArtHZFXg0sHyQ+NIN//gS8P22zgHAcUkOHmT9to39gf8GvKjt/1rgzCnqHQJ8CvjDqvpGksPpksz/ACwD/ldb3vdc4OnAb7XtDxxX2+dy4MvAicBOwJ8Bn0uyrFftxXQ9/P8PXY/Wn7V1p31fq+orwF8Bn66q7avqKTNtTxL/BNyf5IwkhybZcQ7bOAJ4Kd3f4uOB/0PXju4EXAm8ZVL9g4HfBvYF3gCcCvx/wB7Ak4EjW71B2uOXAscAjwTeB+yZ5DcnLf/YLI7l5UzTrvcleQXwLuD3qupy4E+A59N9S7kbcCvw/kmrPRv4Dbo2/S8nxTmjJE8FTgNeBTwG+BBwdpJtetVeBBwC7EnXRr+8rXsI8KfA7wFPAPabWKGqTgU+Cfz31nY+b6btaXxMkDVsP6VrvB+kqm6uqs9V1d1VdSdwEl2DN5WXAO+uqp9U1V3ACcARrWflBcCXqupbVXUv8JdADRjb04FlVfX2qrq3qn4C/A3dh9CgXgKcVlXfrap7WmzP7PfkAC+ka2APbf80ALwa+G9VdWVV3UeXcK5MrxcZeGdV3VZV/wxcCKycRVzQfRCeW1XnVtUvq+p8YDVwWK/OR6vqn6rq53Rf+07sY67v63TbkzZrVXUHXeJWdO3MhnTfOO0yi818tKp+XFW3A38P/LiqvtbakM8AT51U/79X1R1VtQa4HPhqa0cn1n9q
i22Q9vj0qlpTVfe1tu7TdG0MSZ4ErKAbQjKojbXrE44DXg/sV1VrW9mrgTdV1foWx1uBF0xa722th/77dB0g/X/iB3EM8KGquqj19p9BNxRm316d91XVT6vqFrqOlpWt/EV052lNVd3d4hvEdNvTmJgga9iWA7dMLkzyiCQfal+v3QF8E9ghyZZTbGM3up7ZCdfSjZ/fpS1bN7GgNUg3DxjbY4Hd2ldotyW5ja5XdzYfWA+KrTX0N/PgXuzjgLNa70d/3+/t7fcWIJPW+1lv+m66HpbZeCzwwknH92y6HuGZ9jHX93VTY5aWrPYP8curane6HtzdgPfMYhM39KZ/PsX85L+3geoP2B6v48HOAF6cJHS9x2e1hHVQG2vXJ7weeH9Vre+VPRb4Qq9NuxK4f9J689F2vm5S27lHi3mmfTyo7eSh79t0bDsXGBNkDU26q7OXA9+aYvHr6L4Ce0ZVPQr43YnVpqj7U7oGa8KvA/fRNfbXA/2xcg+n+0pswr8Aj+jN/1pveh1wdVXt0Hs9sqr6PawzeVBsSbZr+7+uV+eFwPOTvHbSvl81ad8Pr6r/PYt9z2Qd8PFJ+9iuqt45wLozva+D9tJLmkJV/RA4nS5Rho23VcM2SHv8oL/5qvo2cC/dELoXAx+f5T431q5POAh4c5I/7JWto/s2rt+ubVtV/TZ3U60DTpq0j0dU1eRhcFN5UNtJl1j32XYuEibImndJHpXkuXRjcT9RVT+Yotoj6XowbmsXWEweO9f3KeC/JtkzyfY8MP71PuCzwPOS/E66C07eyoMb9UuBw5LslOTX6HpzJ1wM3JnkjUkenmTLdhHIbG679CngFUlWtvFpfwVcVFXX9Or8lG4s3GuT/HEr+yBwQvtqkiSPTvLCWex3si2SbNt7bQN8gu69Obgd27bpLlKc8mLISWZ6X28AVqR3Gz9J00vyxCSvm/j7S7IH3Rjgb7cqlwK/m+TX012sdsIIw5tNe9z3Mbpxw7+oqqk6QjZmY+36hDV043Lfn+T3W9kHgZMmhqMlWdau6ZirLSe1nVvTDYF5dZJnpLNdkn+f5JEDbO8sus+E30zyCGDyvf9voBtzrQXODzfNpy8luZPuv+83Ae9m+lu8vQd4OHAT3QfEVzay3dPoeie+CVwN/CvdhRq0sXV/QpeMXw/cBdzIA7dO+jjdGLRrgK/SjZujrXs/3YVwK9t2bwI+THdB2kCq6mt0DeDn2v4fzxRjmNs44gOA45P8UVV9ge7CkzPbV5qXA4dOXm8WjqT7gJt4/biq1gETFwNuoDsvr2eAv/sB3tfPtJ83J/nuJsQtbS7upLvt5UVJ/oWu3bucrveWdo3Ap4HLgEuY3XjeTfUeBm+P+z5O1wM+l3uiT9uu97VxxM8F/ibJoXS3Dz0b+Gr7vPk23fs6V8fz4Lbz61W1GnglXfJ/K7CWAS+aq6q/p7uI8cK23sQ/QBNt50eAvdvQjb/bhLg1ZKmyt19LR+uJuA3Yq6quHnM4S4bvq6TJ2tCrG4GnVdVV445nIWp30Lgc2GZS77gWOHuQtegleV67yGQ74K+BH9D1GGsT+L5KmsEfA98xOX6wJH+QZJt0t/J7F90dgUyOFxkTZC0Fh9ON8/0psBdwRPnVyHzwfZU0pSTXAK+lDRHRg7yKrmf9x3R32PjjjVfXQuQQC0mSJKnHHmRJkiSpZ6uZqyw+O++8c61YsWLcYUjSSF1yySU3VdWymWs+lO2mpM3RdO3mkkyQV6xYwerVq8cdhiSNVJJrZ641NdtNSZuj6dpNh1hIkiRJPSbIkiRJUo8JsiRJktRjgixJkiT1mCBLkiRJPSbIkiRJUo8JsiRJktSzJO+DPA6rTjyfm+66d07r7rz91qx+84HzHJEkSZLmwh7keTLX5HhT15UkSdL8MkGWJEmSekyQJUmSpJ6hJ8hJtkzyvSTntPk9k1yUZG2STyfZupVv0+bXtuUrets4oZX/KMnBw45ZkiRJm69R9CC/Friy
N/8u4OSqegJwK3B0Kz8auLWVn9zqkWRv4AjgScAhwAeSbDmCuCVJkrQZGmqCnGR34N8DH27zAfYHPtuqnAE8v00f3uZpyw9o9Q8Hzqyqe6rqamAtsM8w45YkSdLma9g9yO8B3gD8ss0/Britqu5r8+uB5W16ObAOoC2/vdX/VfkU6/xKkmOSrE6yesOGDfN8GJK09NhuStLUhpYgJ3kucGNVXTKsffRV1alVtaqqVi1btmwUu5SkRc12U5KmNswHhTwL+P0khwHbAo8C3gvskGSr1ku8O3Bdq38dsAewPslWwKOBm3vlE/rrSJIkSfNqaD3IVXVCVe1eVSvoLrL7elW9BLgQeEGrdhTwxTZ9dpunLf96VVUrP6Ld5WJPYC/g4mHFLUmSpM3bOB41/UbgzCQnAt8DPtLKPwJ8PMla4Ba6pJqqWpPkLOAK4D7g2Kq6f/RhS5IkaXMwkgS5qr4BfKNN/4Qp7kJRVf8KvHCa9U8CThpehJIkSVLHJ+lJkiRJPSbIkiRJUo8JsiRJktRjgixJkiT1mCBLkiRJPSbIkiRJUo8JsiRJktRjgixJkiT1mCBLkiRJPSbIkiRJUo8JsiRJktRjgixJkiT1DC1BTrJtkouTfD/JmiRva+WnJ7k6yaXttbKVJ8n7kqxNclmSp/W2dVSSq9rrqGHFLEmSJG01xG3fA+xfVXcleRjwrSR/35a9vqo+O6n+ocBe7fUM4BTgGUl2At4CrAIKuCTJ2VV16xBjlyRJ0mZqaD3I1bmrzT6svWojqxwOfKyt921ghyS7AgcD51fVLS0pPh84ZFhxS5IkafM21DHISbZMcilwI12Se1FbdFIbRnFykm1a2XJgXW/19a1suvLJ+zomyeokqzds2DDfhyJJS47tpiRNbagJclXdX1Urgd2BfZI8GTgBeCLwdGAn4I3ztK9Tq2pVVa1atmzZfGxSkpY0201JmtpI7mJRVbcBFwKHVNX1bRjFPcBHgX1ateuAPXqr7d7KpiuXJEmS5t0w72KxLMkObfrhwIHAD9u4YpIEeD5weVvlbOBl7W4W+wK3V9X1wHnAQUl2TLIjcFArkyRJkubdMO9isStwRpIt6RLxs6rqnCRfT7IMCHAp8OpW/1zgMGAtcDfwCoCquiXJO4DvtHpvr6pbhhi3JEmSNmNDS5Cr6jLgqVOU7z9N/QKOnWbZacBp8xqgJEmSNAWfpCdJkiT1mCBLkiRJPSbIkiRJUo8JsiRJktRjgixJkiT1mCBLkiRJPSbIkiRJUo8JsiRJktQzzCfpLUqrTjyfm+66d9xhSJIkaUzsQZ7E5FiSJGnzZoIsSZIk9ZggS5IkST1DS5CTbJvk4iTfT7Imydta+Z5JLkqyNsmnk2zdyrdp82vb8hW9bZ3Qyn+U5OBhxSxJkiQNswf5HmD/qnoKsBI4JMm+wLuAk6vqCcCtwNGt/tHAra385FaPJHsDRwBPAg4BPpBkyyHGLUmSpM3Y0BLk6tzVZh/WXgXsD3y2lZ8BPL9NH97macsPSJJWfmZV3VNVVwNrgX2GFbckSZI2b0Mdg5xkyySXAjcC5wM/Bm6rqvtalfXA8ja9HFgH0JbfDjymXz7FOv19HZNkdZLVGzZsGMLRSNLSYrspSVMbaoJcVfdX1Upgd7pe3ycOcV+nVtWqqlq1bNmyYe1GkpYM201JmtpI7mJRVbcBFwLPBHZIMvGAkt2B69r0dcAeAG35o4Gb++VTrCNJkiTNq2HexWJZkh3a9MOBA4Er6RLlF7RqRwFfbNNnt3na8q9XVbXyI9pdLvYE9gIuHlbckiRJ2rwN81HTuwJntDtObAGcVVXnJLkCODPJicD3gI+0+h8BPp5kLXAL3Z0rqKo1Sc4CrgDuA46tqvuHGLckSZI2Y0NLkKvqMuCpU5T/hCnuQlFV/wq8cJptnQScNN8xSpIkSZP5JD1JkiSpxwRZkiRJ6jFBliRJknqGeZGeZmHF8V+e9To7b781q9984BCikSRJ2nzZ
g7yI3XTXveMOQZIkackxQZYkSZJ6TJAlSZKkHhNkSZIkqccEWZIkSeoxQZYkSZJ6TJAlSZKkHhNkSZIkqWdoCXKSPZJcmOSKJGuSvLaVvzXJdUkuba/DeuuckGRtkh8lObhXfkgrW5vk+GHFLEmSJA3zSXr3Aa+rqu8meSRwSZLz27KTq+qv+5WT7A0cATwJ2A34WpJ/0xa/HzgQWA98J8nZVXXFEGOXJEnSZmpoCXJVXQ9c36bvTHIlsHwjqxwOnFlV9wBXJ1kL7NOWra2qnwAkObPVNUGWJEnSvBvJGOQkK4CnAhe1otckuSzJaUl2bGXLgXW91da3sunKJ+/jmCSrk6zesGHDfB+CJC05tpuSNLWBEuQkzxqkbJp1twc+BxxXVXcApwCPB1bS9TD/j0GD3ZiqOrWqVlXVqmXLls3HJiVpSbPdlKSpDdqD/P8PWPYgSR5Glxx/sqo+D1BVN1TV/VX1S+BveGAYxXXAHr3Vd29l05VLkiRJ826jY5CTPBP4HWBZkj/tLXoUsOUM6wb4CHBlVb27V75rG58M8AfA5W36bOBvk7yb7iK9vYCLgQB7JdmTLjE+AnjxYIcnSZIkzc5MF+ltDWzf6j2yV34H8IIZ1n0W8FLgB0kubWV/DhyZZCVQwDXAqwCqak2Ss+guvrsPOLaq7gdI8hrgPLqk/LSqWjPAsUmSJEmzttEEuar+AfiHJKdX1bWz2XBVfYuu93eyczeyzknASVOUn7ux9SRJkqT5Muht3rZJciqwor9OVe0/jKAkSZKkcRk0Qf4M8EHgw8D9wwtHkiRJGq9BE+T7quqUoUYiSZIkLQCD3ubtS0n+c5Jdk+w08RpqZJIkSdIYDNqDfFT7+fpeWQGPm99wJEmSpPEaKEGuqj2HHYgkSZK0EAyUICd52VTlVfWx+Q1HkiRJGq9Bh1g8vTe9LXAA8F3ABFmSJElLyqBDLP6kP59kB+DMYQQkSZIkjdOgPciT/QvguGRJ0qKy6sTzuemue2e93s7bb83qNx84hIgkLUSDjkH+Et1dKwC2BH4TOGtYQUmSNAxzSY43ZT1Ji9OgPch/3Zu+D7i2qtYPIR5JkiRprAZ6UEhV/QPwQ+CRwI7AjP9KJ9kjyYVJrkiyJslrW/lOSc5PclX7uWMrT5L3JVmb5LIkT+tt66hW/6okR023T0mSJGlTDZQgJ3kRcDHwQuBFwEVJXjDDavcBr6uqvYF9gWOT7A0cD1xQVXsBF7R5gEOBvdrrGOCUtu+dgLcAzwD2Ad4ykVRLkiRJ823QIRZvAp5eVTcCJFkGfA347HQrVNX1wPVt+s4kVwLLgcOB/Vq1M4BvAG9s5R+rqgK+nWSHJLu2uudX1S1t3+cDhwCfGvgoJUnaRCuO//Kc1vMCP2nxGagHGdhiIjlubp7FuiRZATwVuAjYpSXPAD8DdmnTy4F1vdXWt7Lpyifv45gkq5Os3rBhw6ChSdJmy3ZzNLzAT1p8Bk1yv5LkvCQvT/Jy4MvAuYOsmGR74HPAcVV1R39Z6y2uKVecpao6tapWVdWqZcuWzccmJWlJs92UpKltNEFO8oQkz6qq1wMfAn6rvf4PcOpMG0/yMLrk+JNV9flWfEMbOkH7OdEzfR2wR2/13VvZdOWSJEnSvJtpDPJ7gBMAWoL7eYAk/29b9rzpVkwS4CPAlVX17t6is4GjgHe2n1/slb8myZl0F+TdXlXXJzkP+KvehXkHTcQkx8RJkiTNt5kS5F2q6geTC6vqB21c8cY8C3gp8IMkl7ayP6dLjM9KcjRwLd1dMaAbsnEYsBa4G3hF29ctSd4BfKfVe/vEBXuaO8fESZIkTW2mBHmHjSx7+MZWrKpvAZlm8QFT1C/g2Gm2dRpw2sb2J0naPMz1cdGSNKiZLtJbneSVkwuT/BFwyXBCkiRpeibHkoZtph7k44AvJHkJDyTEq4CtgT8YYlySJEnSWGw0Qa6qG4DfSfIc4Mmt+MtV9fWhRyZJkiSNwUBP0quq
C4ELhxyLJEmSNHYDPw1PkiRJ2hyYIEuSJEk9JsiSJElSz0BjkCVJ0tzN5amnPvFUGh97kCVJWoC837M0PibIkiRJUo8JsiRJktRjgixJkiT1DC1BTnJakhuTXN4re2uS65Jc2l6H9ZadkGRtkh8lObhXfkgrW5vk+GHFK0mSJMFwe5BPBw6ZovzkqlrZXucCJNkbOAJ4UlvnA0m2TLIl8H7gUGBv4MhWV5IkSRqKod3mraq+mWTFgNUPB86sqnuAq5OsBfZpy9ZW1U8AkpzZ6l4x3/FKkiRJMJ4xyK9JclkbgrFjK1sOrOvVWd/Kpit/iCTHJFmdZPWGDRuGEbckLSm2m5I0tVEnyKcAjwdWAtcD/2O+NlxVp1bVqqpatWzZsvnarCQtWbabkjS1kT5Jr6pumJhO8jfAOW32OmCPXtXdWxkbKZckSZLm3UgT5CS7VtX1bfYPgIk7XJwN/G2SdwO7AXsBFwMB9kqyJ11ifATw4lHGLEnSuMzlEdXgY6qlTTW0BDnJp4D9gJ2TrAfeAuyXZCVQwDXAqwCqak2Ss+guvrsPOLaq7m/beQ1wHrAlcFpVrRlWzJIkLQU+plraNMO8i8WRUxR/ZCP1TwJOmqL8XODceQxNkiRJmpZP0pMkSZJ6TJAlSZKkHhNkSZIkqccEWZIkSeoxQZYkSZJ6TJAlSZKkHhNkSZIkqWekT9KTJEmjMZen8PkEPqljD7IkSQJ8Ap80wR7kzZi9C5IkSQ9lD7Jmxd4FSZK01JkgS5IkST1DS5CTnJbkxiSX98p2SnJ+kqvazx1beZK8L8naJJcleVpvnaNa/auSHDWseCVJkiQYbg/y6cAhk8qOBy6oqr2AC9o8wKHAXu11DHAKdAk18BbgGcA+wFsmkmpJkiRpGIaWIFfVN4FbJhUfDpzRps8Ant8r/1h1vg3skGRX4GDg/Kq6papuBc7noUm3JEmSNG9GPQZ5l6q6vk3/DNilTS8H1vXqrW9l05U/RJJjkqxOsnrDhg3zG7UkLUG2m5I0tbFdpFdVBdQ8bu/UqlpVVauWLVs2X5uVpCXLdlOSpjbqBPmGNnSC9vPGVn4dsEev3u6tbLpySZIkaShGnSCfDUzcieIo4Iu98pe1u1nsC9zehmKcBxyUZMd2cd5BrUySJEkaiqE9SS/Jp4D9gJ2TrKe7G8U7gbOSHA1cC7yoVT8XOAxYC9wNvAKgqm5J8g7gO63e26tq8oV/kiRpnszlKavgk1a1tAwtQa6qI6dZdMAUdQs4dprtnAacNo+hSZKkeeaTVrWU+CQ9SZIkqccEWZIkSeoxQZYkSZJ6hjYGWZKkjVl14vmOW5W0INmDLEkaC5NjSQuVCbIkSZLUY4IsSZIk9ZggS5IkST1epKdZ8ylLkiRpKbMHWSPjBTmSJGkxMEGWJEmSekyQJUmSpJ6xjEFOcg1wJ3A/cF9VrUqyE/BpYAVwDfCiqro1SYD3AocBdwMvr6rvjiNuSZI0vblco+L1KVqIxnmR3nOq6qbe/PHABVX1ziTHt/k3AocCe7XXM4BT2k9JkrTI3XTXvV78rQVnIQ2xOBw4o02fATy/V/6x6nwb2CHJrmOIT5IkLSBe/K1hGVeCXMBXk1yS5JhWtktVXd+mfwbs0qaXA+t6665vZQ+S5Jgkq5Os3rBhw7DilqQlw3ZTkqY2rgT52VX1NLrhE8cm+d3+wqoquiR6YFV1alWtqqpVy5Ytm8dQJWlpst2UpKmNJUGuquvazxuBLwD7ADdMDJ1oP29s1a8D9uitvnsrkyRJkubdyC/SS7IdsEVV3dmmDwLeDpwNHAW8s/38YlvlbOA1Sc6kuzjv9t5QDEmStBnzzhkahnHcxWIX4Avd3dvYCvjbqvpKku8AZyU5GrgWeFGrfy7dLd7W0t3m7RWjD1nzxYZMkjRuXtynmYw8Qa6qnwBPmaL8ZuCAKcoLOHYEoWmBsiGTJEmjtJBu8yZJkiSN3TgfFCJJkjQWPpxEG2MPsiRJ0oAc
9rd5sAdZkiRpFrzgfOmzB1mSJGnI7HleXEyQJUmSpB6HWGhR8GIKSZI0KibIWtL8SkuStFDY2bN4OMRCkiRpAbOzZ/TsQZYkSVrgvHPGaJkga8mzUZEkbY7seZ47E2RpCjYqkqSlwHHPc7NoEuQkhwDvBbYEPlxV7xxzSFribFQkSZurzb2jaFEkyEm2BN4PHAisB76T5OyqumK8kUkPddNd9zqsQ5K06G3On2WLIkEG9gHWVtVPAJKcCRwOmCBryZhrYg0QoOaw3lJpyDQ+q048f7PvaZL0gE35LFtIn0mLJUFeDqzrza8HntGvkOQY4Jg2e1eSH81xXzsDN81x3YVuKR8bLO3jG8qxXQvkL+Z7q3OylM8djO74HjubyrNoNxfi+TGmwRjT4BZiXJtVTJvwmbQpMU3Zbi6WBHlGVXUqcOqmbifJ6qpaNQ8hLThL+dhgaR/fUj428PjGZdB2cyHGb0yDMabBLcS4jGkww4hpsTwo5Dpgj9787q1MkiRJmleLJUH+DrBXkj2TbA0cAZw95pgkSZK0BC2KIRZVdV+S1wDn0d3m7bSqWjOk3W3yMI0FbCkfGyzt41vKxwYe30K3EOM3psEY0+AWYlzGNJh5jylVc7n2XZIkSVqaFssQC0mSJGkkTJAlSZKkHhPkJskhSX6UZG2S48cdz1wk2SPJhUmuSLImyWtb+U5Jzk9yVfu5YytPkve1Y74sydPGewQzS7Jlku8lOafN75nkonYMn24XcZJkmza/ti1fMdbAB5BkhySfTfLDJFcmeeZSOXdJ/mv7nbw8yaeSbLuYz12S05LcmOTyXtmsz1WSo1r9q5IcNY5j2ZiF0i7O5v0eYUyzam9HFNO2SS5O8v0W09ta+ZR/a6M0aNs9wniuSfKDJJcmWd3Kxv07NfBnwIji+Y32/ky87khy3AJ4nwb+PNkUJsg86FHWhwJ7A0cm2Xu8Uc3JfcDrqmpvYF/g2HYcxwMXVNVewAVtHrrj3au9jgFOGX3Is/Za4Mre/LuAk6vqCcCtwNGt/Gjg1lZ+cqu30L0X+EpVPRF4Ct1xLvpzl2Q58F+AVVX1ZLoLbY9gcZ+704FDJpXN6lwl2Ql4C91Dj/YB3jLqD5qNWWDt4ukM/n6Pymzb21G4B9i/qp4CrAQOSbIv0/+tjdKgbfcoPaeqVvbunzvu36nZfAYMXVX9qL0/K4HfBu4GvjDOmObweTJ3VbXZv4BnAuf15k8AThh3XPNwXF8EDgR+BOzaynYFftSmPwQc2av/q3oL8UV3/+sLgP2Bc+iesHwTsNXk80h3x5NntumtWr2M+xg2cmyPBq6eHONSOHc88CTMndq5OAc4eLGfO2AFcPlczxVwJPChXvmD6o37tdDaxUHf7zHGt9H2dgzxPAL4Lt0/YFP+rY0wloHb7hHGdA2w86SysZ272X4GjOH36SDgH8cd02w/TzblZQ9yZ6pHWS8fUyzzon0t/VTgImCXqrq+LfoZsEubXmzH/R7gDcAv2/xjgNuq6r4234//V8fWlt/e6i9UewIbgI+2ryE/nGQ7lsC5q6rrgL8G/hm4nu5cXMLSOXcTZnuuFvo5XOjxTfd+j9yA7e2oYtkyyaXAjcD5wI+Z/m9tVN7D4G33qBTw1SSXpHvkOoz33M32M2DUjgA+1abHFtMcPk/mzAR5CUqyPfA54LiquqO/rLp/rxbdvf2SPBe4saouGXcsQ7IV8DTglKp6KvAvTPraahGfux2Bw+k+AHYDtuOhX5cvKYv1XC1W43y/F1p7W1X3V/eV+O50Q3eeOMr9T7aA2+5nV9XT6IYQHZvkd/sLx3DuFuxnQBvP+/vAZyYvG3VMo/w8MUHuLJlHWSd5GF1j/cmq+nwrviHJrm35rnQ9C7C4jvtZwO8nuQY4k+6ruvcCOySZeOBNP/5fHVtb/mjg5lEGPEvrgfVVdVGb/yxdY7kUzt3vAVdX1Yaq+gXwebrzuVTO3YTZnquFfg4XenzTvd8jM8v2dqSq
6jbgQrqvm6f7WxuF2bbdI9F6IqmqG+nG1e7DeM/dbD8DRulQ4LtVdUObH2dMs/08mTMT5M6SeJR1kgAfAa6sqnf3Fp0NTFwhfxTdWLmJ8pelsy9we+9rkwWlqk6oqt2ragXd+fl6Vb2E7gPgBa3a5GObOOYXtPoLtkevqn4GrEvyG63oAOAKlsC5o/sqbN8kj2i/oxPHtiTOXc9sz9V5wEFJdmy9Ige1soViobeL073fIzGH9nYUMS1LskObfjjdmOgrmf5vbejm0HYPXZLtkjxyYprub+9yxnju5vAZMEpH8sDwChhvTLP9PJm7UQ2sXugv4DDgn+jGa71p3PHM8RieTfdVx2XApe11GN14rwuAq4CvATu1+qG7Sv3HwA/orgod+3EMcJz7Aee06ccBFwNr6b7+2aaVb9vm17bljxt33AMc10pgdTt/fwfsuFTOHfA24Id0H0IfB7ZZzOeO7sPieuAXdD0/R8/lXAH/qR3nWuAV4z6uKY5zQbSLs3m/RxjTrNrbEcX0W8D3WkyXA3/Zyqf8WxvDeZyx7R5RHI8Dvt9eayZ+txfA79TAnwEjjGk7um/wHt0rG3dMA3+ebMrLR01LkiRJPQ6xkCRJknpMkCVJkqQeE2RJkiSpxwRZkiRJ6jFBliRJknpMkKVJktw15O0fl+QRo9qfJA2TbaaWIhNkafSOAx4xUyVJEmCbqTHYauYqkpI8nu5hD8uAu4FXVtUPk5wO3AGsAn4NeENVfTbJFsD/pHus6jq6BxycRvfs+N2AC5PcVFXPads/CXgu8HPg8HrgkZ6StOjYZmqxswdZGsypwJ9U1W8DfwZ8oLdsV7qnaj0XeGcr+w/ACmBv4KXAMwGq6n3AT4HnTDT0dE8q+nZVPQX4JvDKoR6JJA2fbaYWNXuQpRkk2R74HeAz3aPfge7RlhP+rqp+CVyRZJdW9mzgM638Z0ku3Mgu7gXOadOXAAfOW/CSNGK2mVoKTJClmW0B3FZVK6dZfk9vOtPU2Zhf1APPfL8f/y4lLW62mVr0HGIhzaCq7gCuTvJCgHSeMsNq/wj8YZItWg/Jfr1ldwKPHEqwkjRmtplaCkyQpYd6RJL1vdefAi8Bjk7yfWANcPgM2/gcsB64AvgE8F3g9rbsVOArM3yFKEmLhW2mlpw88C2FpPmUZPuquivJY4CLgWdV1c/GHZckLUS2mVpIHLcjDc85SXYAtgbeYUMvSRtlm6kFwx5kSZIkqccxyJIkSVKPCbIkSZLUY4IsSZIk9ZggS5IkST0myJIkSVLP/wXaqhDAfkdZFQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 720x252 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "d_len = [len(tokenizer.encode(s)) for s in dataset_samsum[\"train\"][\"dialogue\"]]\n",
    "s_len = [len(tokenizer.encode(s)) for s in dataset_samsum[\"train\"][\"summary\"]]\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10, 3.5), sharey=True)\n",
    "axes[0].hist(d_len, bins=20, color=\"C0\", edgecolor=\"C0\")\n",
    "axes[0].set_title(\"Dialogue Token Length\")\n",
    "axes[0].set_xlabel(\"Length\")\n",
    "axes[0].set_ylabel(\"Count\")\n",
    "axes[1].hist(s_len, bins=20, color=\"C0\", edgecolor=\"C0\")\n",
    "axes[1].set_title(\"Summary Token Length\")\n",
    "axes[1].set_xlabel(\"Length\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7d3d42a",
   "metadata": {},
   "source": [
    "We tokenize the inputs, truncating dialogues to at most 1024 tokens and summaries to at most 128 tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "59024f66",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7f4930a5bb7a44e3a468cba309e85670",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/819 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\USER\\anaconda3\\lib\\site-packages\\transformers\\tokenization_utils_base.py:4016: UserWarning: `as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your labels by using the argument `text_target` of the regular `__call__` method (either in the same call as your input texts if you use the same keyword arguments, or in a separate call.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "def convert_examples_to_features(example_batch):\n",
    "    # Tokenize the dialogues (model inputs), truncating to 1024 tokens.\n",
    "    input_encodings = tokenizer(example_batch[\"dialogue\"], max_length=1024,\n",
    "                                truncation=True)\n",
    "    # Tokenize the summaries as targets via `text_target`, which replaces the\n",
    "    # deprecated `as_target_tokenizer` context manager.\n",
    "    target_encodings = tokenizer(text_target=example_batch[\"summary\"],\n",
    "                                 max_length=128, truncation=True)\n",
    "    return {\"input_ids\": input_encodings[\"input_ids\"],\n",
    "            \"attention_mask\": input_encodings[\"attention_mask\"],\n",
    "            \"labels\": target_encodings[\"input_ids\"]}\n",
    "\n",
    "dataset_samsum_pt = dataset_samsum.map(convert_examples_to_features,\n",
    "                                       batched=True)\n",
    "columns = [\"input_ids\", \"labels\", \"attention_mask\"]\n",
    "dataset_samsum_pt.set_format(type=\"torch\", columns=columns)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13c91cb9",
   "metadata": {},
   "source": [
    "We define the data collator, which pads the tensors in a batch to a common length and stacks them. To implement teacher forcing, it also prepares the decoder inputs by shifting the summary labels one position to the right. Padding positions in the labels are set to -100 so that the loss function ignores them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "bfd5a8ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DataCollatorForSeq2Seq\n",
    "\n",
    "seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e450d16d",
   "metadata": {},
   "source": [
    "We now configure and run the training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "95cde7d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "# A per-device batch size of 1 with 16 gradient accumulation steps gives an\n",
    "# effective batch size of 16 per device.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='pegasus-samsum', num_train_epochs=1, warmup_steps=500,\n",
    "    per_device_train_batch_size=1, per_device_eval_batch_size=1,\n",
    "    weight_decay=0.01, logging_steps=10, push_to_hub=True,\n",
    "    evaluation_strategy='steps', eval_steps=500, save_steps=1e6,\n",
    "    gradient_accumulation_steps=16)\n",
    "\n",
    "# Alternative configuration without gradient accumulation or an explicit\n",
    "# evaluation batch size:\n",
    "# training_args = TrainingArguments(\n",
    "#     output_dir='pegasus-samsum', num_train_epochs=1, warmup_steps=500,\n",
    "#     per_device_train_batch_size=1,\n",
    "#     weight_decay=0.01, logging_steps=10, push_to_hub=True,\n",
    "#     evaluation_strategy='steps', eval_steps=500, save_steps=1e6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b7a0c52b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aab1358dc9a2498697e1e36e34fef8c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "579d6c18",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = Trainer(model=model, args=training_args,\n",
    "                  tokenizer=tokenizer, data_collator=seq2seq_data_collator,\n",
    "                  train_dataset=dataset_samsum_pt[\"train\"],\n",
    "                  eval_dataset=dataset_samsum_pt[\"validation\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "ad56f672",
   "metadata": {},
   "outputs": [
    {
     "ename": "OutOfMemoryError",
     "evalue": "CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 4.00 GiB total capacity; 3.43 GiB already allocated; 0 bytes free; 3.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mOutOfMemoryError\u001b[0m                          Traceback (most recent call last)",
      "Input \u001b[1;32mIn [42]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m      2\u001b[0m score \u001b[38;5;241m=\u001b[39m evaluate_summaries_pegasus(\n\u001b[0;32m      3\u001b[0m dataset_samsum[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest\u001b[39m\u001b[38;5;124m\"\u001b[39m], rouge_metric, trainer\u001b[38;5;241m.\u001b[39mmodel, tokenizer,\n\u001b[0;32m      4\u001b[0m batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, column_text\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdialogue\u001b[39m\u001b[38;5;124m\"\u001b[39m, column_summary\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msummary\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m      5\u001b[0m rouge_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m((rn, score[rn]\u001b[38;5;241m.\u001b[39mmid\u001b[38;5;241m.\u001b[39mfmeasure) \u001b[38;5;28;01mfor\u001b[39;00m rn \u001b[38;5;129;01min\u001b[39;00m rouge_names)\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\transformers\\trainer.py:1923\u001b[0m, in \u001b[0;36mTrainer.train\u001b[1;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[0;32m   1920\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m   1921\u001b[0m     \u001b[38;5;66;03m# Disable progress bars when uploading models during checkpoints to avoid polluting stdout\u001b[39;00m\n\u001b[0;32m   1922\u001b[0m     hf_hub_utils\u001b[38;5;241m.\u001b[39mdisable_progress_bars()\n\u001b[1;32m-> 1923\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1924\u001b[0m \u001b[43m        \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1925\u001b[0m \u001b[43m        \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1926\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1927\u001b[0m \u001b[43m        \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1928\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1929\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m   1930\u001b[0m     hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\transformers\\trainer.py:2268\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[1;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[0;32m   2265\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[0;32m   2267\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[1;32m-> 2268\u001b[0m     tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   2270\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[0;32m   2271\u001b[0m     args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[0;32m   2272\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[0;32m   2273\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[0;32m   2274\u001b[0m ):\n\u001b[0;32m   2275\u001b[0m     \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[0;32m   2276\u001b[0m     tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\transformers\\trainer.py:3324\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[1;34m(***failed resolving arguments***)\u001b[0m\n\u001b[0;32m   3322\u001b[0m         scaled_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m   3323\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 3324\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39mbackward(loss, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m   3326\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mgradient_accumulation_steps\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\accelerate\\accelerator.py:2134\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[1;34m(self, loss, **kwargs)\u001b[0m\n\u001b[0;32m   2132\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlomo_backward(loss, learning_rate)\n\u001b[0;32m   2133\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 2134\u001b[0m     loss\u001b[38;5;241m.\u001b[39mbackward(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\_tensor.py:487\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m    477\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    478\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m    479\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[0;32m    480\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    485\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[0;32m    486\u001b[0m     )\n\u001b[1;32m--> 487\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m    488\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[0;32m    489\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\autograd\\__init__.py:200\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m    195\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[0;32m    197\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[0;32m    198\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m    199\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 200\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[0;32m    201\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m    202\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
      "\u001b[1;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 4.00 GiB total capacity; 3.43 GiB already allocated; 0 bytes free; 3.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF"
     ]
    }
   ],
   "source": [
    "# Fine-tune the model, then score it on the SAMSum test split with ROUGE.\n",
    "trainer.train()\n",
    "score = evaluate_summaries_pegasus(\n",
    "    dataset_samsum[\"test\"], rouge_metric, trainer.model, tokenizer,\n",
    "    batch_size=2, column_text=\"dialogue\", column_summary=\"summary\")\n",
    "# Collect the mid F-measure of each ROUGE variant into a single row.\n",
    "rouge_dict = {rn: score[rn].mid.fmeasure for rn in rouge_names}\n",
    "pd.DataFrame(rouge_dict, index=[\"pegasus\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7868de6a",
   "metadata": {},
   "source": [
    "Training could not be completed: the run failed with a CUDA out-of-memory error on a GPU with 4 GiB of total capacity, so no ROUGE scores are reported for the fine-tuned model. Fine-tuning on the SAMSum dataset would be expected to improve the ROUGE scores over the PEGASUS model before fine-tuning, but that expectation is not verified here."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
